{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using TensorFlow backend.\n",
      "/data/users/fyx/.local/python3/lib/python3.6/importlib/_bootstrap.py:219: RuntimeWarning: numpy.dtype size changed, may indicate binary incompatibility. Expected 96, got 88\n",
      "  return f(*args, **kwds)\n",
      "/data/users/fyx/.local/python3/lib/python3.6/importlib/_bootstrap.py:219: RuntimeWarning: numpy.dtype size changed, may indicate binary incompatibility. Expected 96, got 88\n",
      "  return f(*args, **kwds)\n"
     ]
    }
   ],
   "source": [
    "import keras\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matchzoo as mz"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fetch the three WikiQA splits (train / dev / test) as ranking-task data packs.\n",
    "# NOTE(review): `filter=True` on dev/test presumably drops questions that have no\n",
    "# correct answer so ranking metrics are well defined -- confirm against the installed MatchZoo.\n",
    "load_wiki_qa = mz.datasets.wiki_qa.load_data\n",
    "\n",
    "train_pack = load_wiki_qa('train', task='ranking')\n",
    "valid_pack = load_wiki_qa('dev', task='ranking', filter=True)\n",
    "predict_pack = load_wiki_qa('test', task='ranking', filter=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing text_left with chain_transform of TokenizeUnit => LowercaseUnit => PuncRemovalUnit => StopRemovalUnit => NgramLetterUnit: 100%|██████████| 2118/2118 [00:00<00:00, 6893.66it/s]\n",
      "Processing text_right with chain_transform of TokenizeUnit => LowercaseUnit => PuncRemovalUnit => StopRemovalUnit => NgramLetterUnit: 100%|██████████| 18841/18841 [00:05<00:00, 3323.25it/s]\n",
      "Processing text_left with extend: 100%|██████████| 2118/2118 [00:00<00:00, 514034.02it/s]\n",
      "Processing text_right with extend: 100%|██████████| 18841/18841 [00:00<00:00, 430982.12it/s]\n",
      "Building VocabularyUnit from a datapack.: 100%|██████████| 1614976/1614976 [00:00<00:00, 2964446.21it/s]\n",
      "Processing text_left with chain_transform of TokenizeUnit => LowercaseUnit => PuncRemovalUnit => StopRemovalUnit => NgramLetterUnit => WordHashingUnit: 100%|██████████| 2118/2118 [00:00<00:00, 4878.11it/s]\n",
      "Processing text_right with chain_transform of TokenizeUnit => LowercaseUnit => PuncRemovalUnit => StopRemovalUnit => NgramLetterUnit => WordHashingUnit: 100%|██████████| 18841/18841 [00:07<00:00, 2572.20it/s]\n"
     ]
    }
   ],
   "source": [
    "# Fit the DSSM preprocessor on the training pack only and transform that same pack.\n",
    "# The fitted `preprocessor` is reused by later cells: it transforms the dev/test packs\n",
    "# and supplies `preprocessor.context['input_shapes']` to the model.\n",
    "preprocessor = mz.preprocessors.DSSMPreprocessor()\n",
    "train_pack_processed = preprocessor.fit_transform(train_pack)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing text_left with chain_transform of TokenizeUnit => LowercaseUnit => PuncRemovalUnit => StopRemovalUnit => NgramLetterUnit => WordHashingUnit: 100%|██████████| 122/122 [00:00<00:00, 4624.45it/s]\n",
      "Processing text_right with chain_transform of TokenizeUnit => LowercaseUnit => PuncRemovalUnit => StopRemovalUnit => NgramLetterUnit => WordHashingUnit: 100%|██████████| 1115/1115 [00:00<00:00, 2609.60it/s]\n",
      "Processing text_left with chain_transform of TokenizeUnit => LowercaseUnit => PuncRemovalUnit => StopRemovalUnit => NgramLetterUnit => WordHashingUnit: 100%|██████████| 237/237 [00:00<00:00, 5193.33it/s]\n",
      "Processing text_right with chain_transform of TokenizeUnit => LowercaseUnit => PuncRemovalUnit => StopRemovalUnit => NgramLetterUnit => WordHashingUnit: 100%|██████████| 2300/2300 [00:00<00:00, 2579.14it/s]\n"
     ]
    }
   ],
   "source": [
    "# Push the dev and test packs through the already-fitted preprocessor\n",
    "# (transform only, dev first and then test -- nothing is refitted here).\n",
    "valid_pack_processed, predict_pack_processed = (\n",
    "    preprocessor.transform(raw_pack) for raw_pack in (valid_pack, predict_pack)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Ranking task optimised with a rank cross-entropy loss (num_neg=4, the same value the\n",
    "# training pair generator is given) and scored with NDCG@3, NDCG@5 and MAP, in that order.\n",
    "ranking_task = mz.tasks.Ranking(\n",
    "    loss=mz.losses.RankCrossEntropyLoss(num_neg=4)\n",
    ")\n",
    "ranking_task.metrics = [\n",
    "    *(mz.metrics.NormalizedDiscountedCumulativeGain(k=cutoff) for cutoff in (3, 5)),\n",
    "    mz.metrics.MeanAveragePrecision(),\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Parameter \"name\" set to DSSM.\n",
      "__________________________________________________________________________________________________\n",
      "Layer (type)                    Output Shape         Param #     Connected to                     \n",
      "==================================================================================================\n",
      "text_left (InputLayer)          (None, 9644)         0                                            \n",
      "__________________________________________________________________________________________________\n",
      "text_right (InputLayer)         (None, 9644)         0                                            \n",
      "__________________________________________________________________________________________________\n",
      "dense_1 (Dense)                 (None, 300)          2893500     text_left[0][0]                  \n",
      "__________________________________________________________________________________________________\n",
      "dense_5 (Dense)                 (None, 300)          2893500     text_right[0][0]                 \n",
      "__________________________________________________________________________________________________\n",
      "dense_2 (Dense)                 (None, 300)          90300       dense_1[0][0]                    \n",
      "__________________________________________________________________________________________________\n",
      "dense_6 (Dense)                 (None, 300)          90300       dense_5[0][0]                    \n",
      "__________________________________________________________________________________________________\n",
      "dense_3 (Dense)                 (None, 300)          90300       dense_2[0][0]                    \n",
      "__________________________________________________________________________________________________\n",
      "dense_7 (Dense)                 (None, 300)          90300       dense_6[0][0]                    \n",
      "__________________________________________________________________________________________________\n",
      "dense_4 (Dense)                 (None, 128)          38528       dense_3[0][0]                    \n",
      "__________________________________________________________________________________________________\n",
      "dense_8 (Dense)                 (None, 128)          38528       dense_7[0][0]                    \n",
      "__________________________________________________________________________________________________\n",
      "dot_1 (Dot)                     (None, 1)            0           dense_4[0][0]                    \n",
      "                                                                 dense_8[0][0]                    \n",
      "__________________________________________________________________________________________________\n",
      "dense_9 (Dense)                 (None, 1)            2           dot_1[0][0]                      \n",
      "==================================================================================================\n",
      "Total params: 6,225,258\n",
      "Trainable params: 6,225,258\n",
      "Non-trainable params: 0\n",
      "__________________________________________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "# Configure DSSM: input shapes come from the fitted preprocessor, and the loss/metrics\n",
    "# come from `ranking_task`.\n",
    "model = mz.models.DSSM()\n",
    "model.params['input_shapes'] = preprocessor.context['input_shapes']\n",
    "model.params['task'] = ranking_task\n",
    "# MLP settings applied to each text input; the printed summary shows three 300-unit\n",
    "# Dense layers followed by a 128-unit Dense layer per input.\n",
    "model.params['mlp_num_layers'] = 3\n",
    "model.params['mlp_num_units'] = 300\n",
    "model.params['mlp_num_fan_out'] = 128\n",
    "model.params['mlp_activation_func'] = 'relu'\n",
    "# Fill whatever parameters are still unset (the output reports `name` being set to DSSM),\n",
    "# then build and compile the backend model. These must run after the assignments above.\n",
    "model.guess_and_fill_missing_params()\n",
    "model.build()\n",
    "model.compile()\n",
    "# Print the layer-by-layer summary of the backend model.\n",
    "model.backend.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Unpack the processed test pack into model inputs and labels for the per-epoch metric callback.\n",
    "# NOTE(review): the callback output is labelled 'Validation', but this is the test split and\n",
    "# valid_pack_processed is never used -- confirm which split should drive monitoring.\n",
    "pred_x, pred_y = predict_pack_processed[:].unpack()\n",
    "# Evaluate in a single batch sized by the number of labelled samples. The model has two\n",
    "# named inputs (text_left / text_right), so pred_x is presumably a mapping of input name\n",
    "# to array, and len(pred_x) would count inputs rather than samples.\n",
    "evaluate = mz.callbacks.EvaluateAllMetrics(model, x=pred_x, y=pred_y, batch_size=len(pred_y))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "16"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Build the shuffled pair generator that feeds training; num_neg matches the loss setting.\n",
    "train_generator = mz.PairDataGenerator(\n",
    "    train_pack_processed,\n",
    "    num_dup=1,\n",
    "    num_neg=4,\n",
    "    batch_size=64,\n",
    "    shuffle=True,\n",
    ")\n",
    "# Trailing expression: displays the number of batches the generator yields per epoch.\n",
    "len(train_generator)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/20\n",
      "16/16 [==============================] - 3s 175ms/step - loss: 1.5564\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.45901904458281206 - normalized_discounted_cumulative_gain@5(0): 0.5475429093192175 - mean_average_precision(0): 0.4911167747235369\n",
      "Epoch 2/20\n",
      "16/16 [==============================] - 2s 96ms/step - loss: 1.2718\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.4654009025837788 - normalized_discounted_cumulative_gain@5(0): 0.5392725183775516 - mean_average_precision(0): 0.48941174032387735\n",
      "Epoch 3/20\n",
      "16/16 [==============================] - 2s 100ms/step - loss: 1.1539\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.4791185363644891 - normalized_discounted_cumulative_gain@5(0): 0.5538418263100713 - mean_average_precision(0): 0.5071942064030672\n",
      "Epoch 4/20\n",
      "16/16 [==============================] - 2s 99ms/step - loss: 1.0995\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.4886745351048811 - normalized_discounted_cumulative_gain@5(0): 0.5562701289500901 - mean_average_precision(0): 0.5133703768780384\n",
      "Epoch 5/20\n",
      "16/16 [==============================] - 2s 95ms/step - loss: 1.0674\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.4853827809349306 - normalized_discounted_cumulative_gain@5(0): 0.5682223793780434 - mean_average_precision(0): 0.5169149799106053\n",
      "Epoch 6/20\n",
      "16/16 [==============================] - 2s 99ms/step - loss: 1.0479\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.46480280740048857 - normalized_discounted_cumulative_gain@5(0): 0.5281305738905067 - mean_average_precision(0): 0.492201293611383\n",
      "Epoch 7/20\n",
      "16/16 [==============================] - 2s 94ms/step - loss: 1.0274\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.4898804500151732 - normalized_discounted_cumulative_gain@5(0): 0.560890920302221 - mean_average_precision(0): 0.5058987750652866\n",
      "Epoch 8/20\n",
      "16/16 [==============================] - 2s 97ms/step - loss: 1.0142\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.4941487381639577 - normalized_discounted_cumulative_gain@5(0): 0.5665778188888335 - mean_average_precision(0): 0.5140344974996873\n",
      "Epoch 9/20\n",
      "16/16 [==============================] - 2s 98ms/step - loss: 1.0008\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.5069173712527351 - normalized_discounted_cumulative_gain@5(0): 0.5866287176354987 - mean_average_precision(0): 0.5269820047509921\n",
      "Epoch 10/20\n",
      "16/16 [==============================] - 2s 97ms/step - loss: 0.9873\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.49700543338958786 - normalized_discounted_cumulative_gain@5(0): 0.5805879493729443 - mean_average_precision(0): 0.5150557804829956\n",
      "Epoch 11/20\n",
      "16/16 [==============================] - 2s 109ms/step - loss: 0.9750\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.5025024116631714 - normalized_discounted_cumulative_gain@5(0): 0.5923504552250592 - mean_average_precision(0): 0.5225994206215725\n",
      "Epoch 12/20\n",
      "16/16 [==============================] - 2s 98ms/step - loss: 0.9644\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.5103579714392016 - normalized_discounted_cumulative_gain@5(0): 0.5903011924569881 - mean_average_precision(0): 0.5302960840144384\n",
      "Epoch 13/20\n",
      "16/16 [==============================] - 2s 95ms/step - loss: 0.9534\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.5160164432378087 - normalized_discounted_cumulative_gain@5(0): 0.5993189848710705 - mean_average_precision(0): 0.5396924014803761\n",
      "Epoch 14/20\n",
      "16/16 [==============================] - 2s 97ms/step - loss: 0.9442\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.5032196943848134 - normalized_discounted_cumulative_gain@5(0): 0.5852110589773137 - mean_average_precision(0): 0.5268703682283354\n",
      "Epoch 15/20\n",
      "16/16 [==============================] - 2s 95ms/step - loss: 0.9336\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.5317052606549753 - normalized_discounted_cumulative_gain@5(0): 0.6023052111746768 - mean_average_precision(0): 0.5478608387627374\n",
      "Epoch 16/20\n",
      "16/16 [==============================] - 2s 97ms/step - loss: 0.9239\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.508831525587222 - normalized_discounted_cumulative_gain@5(0): 0.5817387639362221 - mean_average_precision(0): 0.5274733502791022\n",
      "Epoch 17/20\n",
      "16/16 [==============================] - 2s 97ms/step - loss: 0.9144\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.5203610304609562 - normalized_discounted_cumulative_gain@5(0): 0.5991831842478874 - mean_average_precision(0): 0.5431669157143841\n",
      "Epoch 18/20\n",
      "16/16 [==============================] - 2s 98ms/step - loss: 0.9071\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.5138966254324063 - normalized_discounted_cumulative_gain@5(0): 0.5919938110840292 - mean_average_precision(0): 0.5360874679703793\n",
      "Epoch 19/20\n",
      "16/16 [==============================] - 2s 97ms/step - loss: 0.8990\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.4872557976860261 - normalized_discounted_cumulative_gain@5(0): 0.5698167643783474 - mean_average_precision(0): 0.5167562193088068\n",
      "Epoch 20/20\n",
      "16/16 [==============================] - 2s 98ms/step - loss: 0.8915\n",
      "Validation: normalized_discounted_cumulative_gain@3(0): 0.5492809107350253 - normalized_discounted_cumulative_gain@5(0): 0.6194080901274281 - mean_average_precision(0): 0.5626976311311754\n"
     ]
    }
   ],
   "source": [
    "# Train for 20 epochs from the pair generator; the `evaluate` callback prints the\n",
    "# ranking metrics after every epoch. Batches are prepared by 5 thread workers\n",
    "# (multiprocessing is disabled).\n",
    "history = model.fit_generator(\n",
    "    train_generator,\n",
    "    epochs=20,\n",
    "    callbacks=[evaluate],\n",
    "    workers=5,\n",
    "    use_multiprocessing=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
